#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
@File      :   STSViT.py
@Time      :   2024/02/26 20:18:25
@Author    :   CSDN迪菲赫尔曼 
@Version   :   1.0
@Reference :   https://blog.csdn.net/weixin_43694096
@Desc      :   Super-token attention (StokenAttention) and its fixed unfold/fold helpers.
"""


import torch
import torch.nn as nn
import torch.nn.functional as F

__all__ = ["StokenAttention"]


class Unfold(nn.Module):
    """Im2col via a fixed convolution.

    The kernel is a stack of k*k one-hot filters (an identity matrix reshaped
    to (k*k, 1, k, k)), so each output channel copies one neighbor position;
    the result matches F.unfold(x, k, padding=k // 2).
    """

    def __init__(self, kernel_size=3):
        super().__init__()

        self.kernel_size = kernel_size

        # k*k one-hot filters: filter i extracts position i of the k x k window.
        weights = torch.eye(kernel_size**2)
        weights = weights.reshape(kernel_size**2, 1, kernel_size, kernel_size)
        self.weights = nn.Parameter(weights, requires_grad=False)

    def forward(self, x):
        b, c, h, w = x.shape
        # Fold channels into the batch so one shared kernel unfolds every channel.
        x = F.conv2d(
            x.reshape(b * c, 1, h, w),
            self.weights,
            stride=1,
            padding=self.kernel_size // 2,
        )
        return x.reshape(b, c * self.kernel_size**2, h * w)
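

# A sanity check added here (my reading of the eye-kernel trick, not part of
# the original file): with the one-hot kernels above, Unfold(3) should
# reproduce F.unfold with kernel_size=3 and padding=1 exactly.
def _check_unfold_matches_f_unfold():
    x = torch.randn(2, 4, 8, 8)
    out = Unfold(3)(x)                           # (2, 4*9, 64)
    ref = F.unfold(x, kernel_size=3, padding=1)  # (2, 4*9, 64)
    assert torch.allclose(out, ref, atol=1e-6)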


class Fold(nn.Module):
    """Adjoint of Unfold: col2im via a fixed transposed convolution.

    The same one-hot filters are used, so overlapping contributions are
    summed back onto the spatial grid, matching F.fold with padding=k // 2.
    """

    def __init__(self, kernel_size=3):
        super().__init__()

        self.kernel_size = kernel_size

        # Same one-hot filters as Unfold; conv_transpose2d makes this the
        # exact adjoint (scatter-add) of the unfold above.
        weights = torch.eye(kernel_size**2)
        weights = weights.reshape(kernel_size**2, 1, kernel_size, kernel_size)
        self.weights = nn.Parameter(weights, requires_grad=False)

    def forward(self, x):
        b, _, h, w = x.shape
        x = F.conv_transpose2d(x, self.weights, stride=1, padding=self.kernel_size // 2)
        return x
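

# Companion check (again an assumption mirroring the one above): Fold(3)
# should match F.fold, i.e. sum overlapping 3x3 windows back onto the grid.
def _check_fold_matches_f_fold():
    y = torch.randn(2, 9, 8, 8)
    out = Fold(3)(y)  # (2, 1, 8, 8)
    ref = F.fold(y.reshape(2, 9, 64), output_size=(8, 8), kernel_size=3, padding=1)
    assert torch.allclose(out, ref, atol=1e-6)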


class StAttention(nn.Module):
    """Multi-head self-attention over a (B, C, H, W) feature map.

    q, k, v come from a single 1x1 convolution; logits are k^T q scaled by
    head_dim**-0.5 and softmax-normalized over the key axis, i.e. standard
    scaled dot-product attention in a channels-first layout.
    """

    def __init__(
        self,
        dim,
        window_size=None,  # unused in this implementation
        num_heads=8,
        qkv_bias=False,
        qk_scale=None,
        attn_drop=0.0,
        proj_drop=0.0,
    ):
        super().__init__()

        self.dim = dim
        self.num_heads = num_heads
        head_dim = dim // num_heads

        self.window_size = window_size

        self.scale = qk_scale or head_dim**-0.5

        self.qkv = nn.Conv2d(dim, dim * 3, 1, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Conv2d(dim, dim, 1)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x):
        B, C, H, W = x.shape
        N = H * W

        # A single 1x1 conv produces q, k, v; split along the per-head
        # channel axis.
        q, k, v = (
            self.qkv(x)
            .reshape(B, self.num_heads, C // self.num_heads * 3, N)
            .chunk(3, dim=2)
        )  # each (B, num_heads, head_dim, N)

        attn = (k.transpose(-1, -2) @ q) * self.scale  # (B, num_heads, N, N)

        # Softmax over the key axis: column j holds the attention weights
        # for query j.
        attn = attn.softmax(dim=-2)
        attn = self.attn_drop(attn)

        x = (v @ attn).reshape(B, C, H, W)

        x = self.proj(x)
        x = self.proj_drop(x)
        return x
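

# Illustrative shape check (not from the original file): StAttention keeps
# the (B, C, H, W) layout, so it can drop straight into a conv pipeline.
def _check_st_attention_shape():
    attn = StAttention(dim=64, num_heads=8)
    x = torch.randn(2, 64, 7, 5)
    assert attn(x).shape == x.shape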


class StokenAttention(nn.Module):
    """Super-token attention.

    Pixels are soft-assigned to a coarse grid of super tokens (each pixel
    sees only its 3x3 neighboring super tokens), full self-attention runs
    on the coarse grid, and the refined tokens are scattered back to pixel
    resolution through the same soft assignments.
    """

    def __init__(
        self,
        dim,
        stoken_size=(8, 8),  # (h, w): pixel extent of one super token
        n_iter=1,  # number of soft-assignment refinement iterations
        num_heads=8,
        qkv_bias=False,
        qk_scale=None,
        attn_drop=0.0,
        proj_drop=0.0,
    ):
        super().__init__()

        self.n_iter = n_iter
        self.stoken_size = stoken_size

        self.scale = dim**-0.5

        # Fixed (non-learned) im2col/col2im over the 3x3 neighborhood of the
        # super-token grid.
        self.unfold = Unfold(3)
        self.fold = Fold(3)

        self.stoken_refine = StAttention(
            dim,
            num_heads=num_heads,
            qkv_bias=qkv_bias,
            qk_scale=qk_scale,
            attn_drop=attn_drop,
            proj_drop=proj_drop,
        )

    def stoken_forward(self, x):
        """
        x: (B, C, H, W)
        """
        B, C, H0, W0 = x.shape
        h, w = self.stoken_size

        # Pad on the right/bottom so H and W divide the super-token size.
        pad_l = pad_t = 0
        pad_r = (w - W0 % w) % w
        pad_b = (h - H0 % h) % h
        if pad_r > 0 or pad_b > 0:
            x = F.pad(x, (pad_l, pad_r, pad_t, pad_b))

        _, _, H, W = x.shape

        hh, ww = H // h, W // w

        # Initialize each super token as the mean of its h x w pixel cell.
        stoken_features = F.adaptive_avg_pool2d(x, (hh, ww))  # (B, C, hh, ww)

        # Regroup pixels by super-token cell: (B, hh*ww cells, h*w pixels, C).
        pixel_features = (
            x.reshape(B, C, hh, h, ww, w)
            .permute(0, 2, 4, 3, 5, 1)
            .reshape(B, hh * ww, h * w, C)
        )

        # EM-style refinement of the pixel-to-super-token soft assignments.
        # Run without gradients; the final aggregation is recomputed below so
        # gradients still flow through pixel_features.
        with torch.no_grad():
            for idx in range(self.n_iter):
                # Gather, for every cell, its 3x3 neighborhood of super tokens.
                stoken_features = self.unfold(stoken_features)  # (B, C*9, hh*ww)
                stoken_features = stoken_features.transpose(1, 2).reshape(
                    B, hh * ww, C, 9
                )
                # Scaled dot-product similarity between each pixel and its 9
                # neighboring super tokens.
                affinity_matrix = (
                    pixel_features @ stoken_features * self.scale
                )  # (B, hh*ww, h*w, 9)

                affinity_matrix = affinity_matrix.softmax(-1)  # (B, hh*ww, h*w, 9)

                # Total assignment mass received by each super token, folded
                # back onto the coarse grid; used as the normalizer below.
                affinity_matrix_sum = (
                    affinity_matrix.sum(2).transpose(1, 2).reshape(B, 9, hh, ww)
                )

                affinity_matrix_sum = self.fold(affinity_matrix_sum)  # (B, 1, hh, ww)
                if idx < self.n_iter - 1:
                    # Intermediate update: super tokens become the
                    # assignment-weighted mean of their pixels. The last
                    # iteration skips this; the gradient-tracked version of
                    # the same update follows the loop.
                    stoken_features = (
                        pixel_features.transpose(-1, -2) @ affinity_matrix
                    )  # (B, hh*ww, C, 9)

                    stoken_features = self.fold(
                        stoken_features.permute(0, 2, 3, 1).reshape(B * C, 9, hh, ww)
                    ).reshape(B, C, hh, ww)

                    stoken_features = stoken_features / (
                        affinity_matrix_sum + 1e-12
                    )  # (B, C, hh, ww)

        # Final aggregation, this time tracked by autograd (gradients flow
        # through pixel_features; the assignments stay detached).
        stoken_features = (
            pixel_features.transpose(-1, -2) @ affinity_matrix
        )  # (B, hh*ww, C, 9)

        stoken_features = self.fold(
            stoken_features.permute(0, 2, 3, 1).reshape(B * C, 9, hh, ww)
        ).reshape(B, C, hh, ww)

        stoken_features = stoken_features / (
            affinity_matrix_sum.detach() + 1e-12
        )  # (B, C, hh, ww)

        # Full self-attention among the (much fewer) super tokens.
        stoken_features = self.stoken_refine(stoken_features)

        # Scatter the refined super tokens back to pixel resolution through
        # the same soft assignments (token upsampling).
        stoken_features = self.unfold(stoken_features)  # (B, C*9, hh*ww)
        stoken_features = stoken_features.transpose(1, 2).reshape(
            B, hh * ww, C, 9
        )  # (B, hh*ww, C, 9)

        pixel_features = stoken_features @ affinity_matrix.transpose(
            -1, -2
        )  # (B, hh*ww, C, h*w)

        pixel_features = (
            pixel_features.reshape(B, hh, ww, C, h, w)
            .permute(0, 3, 1, 4, 2, 5)
            .reshape(B, C, H, W)
        )

        # Drop the padding added at the top.
        if pad_r > 0 or pad_b > 0:
            pixel_features = pixel_features[:, :, :H0, :W0]

        return pixel_features

    def direct_forward(self, x):
        # With 1x1 super tokens there is nothing to pool; apply the
        # attention block directly at pixel resolution.
        return self.stoken_refine(x)

    def forward(self, x):
        if self.stoken_size[0] > 1 or self.stoken_size[1] > 1:
            return self.stoken_forward(x)
        else:
            return self.direct_forward(x)
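

if __name__ == "__main__":
    # Smoke test (mine, with assumed sizes): run the helper checks above and
    # exercise StokenAttention on a spatial size that is not divisible by
    # stoken_size, so the padding/cropping path is also covered.
    _check_unfold_matches_f_unfold()
    _check_fold_matches_f_fold()
    _check_st_attention_shape()

    module = StokenAttention(dim=64, stoken_size=(8, 8), num_heads=8)
    x = torch.randn(2, 64, 37, 29)
    y = module(x)
    assert y.shape == x.shape
    print("StokenAttention OK:", tuple(y.shape))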
